Spark - take() 算子

以后遇到不懂的 Spark 算子的话,我都尽可能以笔记的方式去记录它

遇到的情况

1
2
3
rdd.map(...)	 // 重要前提:数据量在 TB 级别
.filter(...) // 根据某些条件筛选数据
.take(100000) // 取当前数据的前十万条

当时的程序大致就是这样,我的想法是根据filter()之后的数据直接利用take()拿前十万的数据,感觉方便又省事,但是实际的运行情况却是作业的运行时间很长,让人怀疑人生。而且take()一开始默认的分区是1,而后如果当前任务失败的话,会适当的扩增分区数来读取更多的数据。

源码分析

废话不多,先贴源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
/**
* Take the first num elements of the RDD. It works by first scanning one partition, and use the results from that partition to estimate the number of additional partitions needed to satisfy the limit.
*
* @note This method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory.
* 此方法被使用时期望目标数组的大小比较小,即其数组中所有数据都能够存储在 driver 的内存当中。这里的函数解释当中提及到了处理的数据量应当较小,但是没说如果处理了比较大的数据时会怎么样,还得看看继续往下看
*
* @note Due to complications in the internal implementation, this method will raise
* an exception if called on an RDD of `Nothing` or `Null`.
*/
def take(num: Int): Array[T] = withScope {
// scaleUpFactor 字面意思是扩增因子,看到这里我们可以结合上图的例子,不难看出分区的扩增是按照一定的倍数增长的
val scaleUpFactor = Math.max(conf.getInt("spark.rdd.limit.scaleUpFactor", 4), 2)
if (num == 0) {
new Array[T](0)
} else {
val buf = new ArrayBuffer[T]
val totalParts = this.partitions.length
var partsScanned = 0

// 这个循环是为什么 take 失败会进行重试的关键
while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
// numPartsToTry - 此次循环迭代的分区个数,默认为1。
var numPartsToTry = 1L
val left = num - buf.size
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
// it by 50%. We also cap the estimation in the end.
// 重点!当在上一次迭代当中,我们没有找到任何满足条件的 row 时(至少是不满足指定数量时),有规律的重试(quadruple and retry,翻译水平有限)
if (buf.isEmpty) {
numPartsToTry = partsScanned * scaleUpFactor
} else {
// As left > 0, numPartsToTry is always >= 1
numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
}
}

val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

// 每一次循环迭代都会获取新的数据加到 buf 当中,所以并不是每一次重试都是从头对数据进行遍历,那这样会没完没了
res.foreach(buf ++= _.take(num - buf.size))
partsScanned += p.size
}

// 这里才是我们最终的结果
buf.toArray
}
}

总结

take()算子使用的场景是当数据量规模较小的情况,亦或者说搭配filter()时,filter()能够较快的筛选出数据来。